# roc_analysis.R
# Generate ROC curves for BERT violence classifier
# 1. BERT vs LLM consensus (ground truth validation)
# 2. BERT vs Dictionary (method comparison)

library(tidyverse)
library(pROC)

# ============================================================
# CONFIGURATION
# ============================================================

# Input locations and output directory.
# Each path can be overridden with an environment variable so the script runs
# unchanged on machines where the data live elsewhere (the dictionary default
# points into one user's Downloads folder). When the variable is unset the
# original hard-coded default is used.
BERT_SCORES_FILE <- Sys.getenv(
  "ROC_BERT_SCORES_FILE",
  unset = "vpi_scored_comments.csv"
)
DICTIONARY_SCORES_FILE <- Sys.getenv(
  "ROC_DICTIONARY_SCORES_FILE",
  unset = "~/Downloads/comments_score.rds"
)
LLM_RESULTS_DIR <- Sys.getenv(
  "ROC_LLM_RESULTS_DIR",
  unset = "multi_model_results"
)
OUTPUT_DIR <- Sys.getenv("ROC_OUTPUT_DIR", unset = "roc_output")

# recursive = TRUE allows a nested output path. Warnings stay suppressed so an
# already-existing directory is silent, but a directory that still does not
# exist afterwards is a hard error here rather than a confusing graphics-device
# failure later in the script.
dir.create(OUTPUT_DIR, showWarnings = FALSE, recursive = TRUE)
if (!dir.exists(OUTPUT_DIR)) {
  stop("Could not create output directory: ", OUTPUT_DIR, call. = FALSE)
}

# ============================================================
# LOAD DATA
# ============================================================

# BERT output, one row per comment. The columns listed below are given an
# explicit type using readr's compact codes ("c" = character, "d" = double);
# columns not listed fall back to readr's type guessing.
cat("Loading BERT scores...\n")
bert_full <- BERT_SCORES_FILE %>%
  read_csv(
    col_types = cols(
      comment_id = "c",
      text_original = "c",
      vpi_binary_label = "d",
      vpi_binary_confidence = "d",
      vpi_intensity_score = "d",
      source_db = "c"
    )
  )
cat("  Loaded:", nrow(bert_full), "comments\n")

# Dictionary scores: keep only the join key and the raw score, then derive a
# 0/1 label that is 1 when scalar_sum is at least 1.
cat("Loading dictionary scores...\n")
dict_data <- DICTIONARY_SCORES_FILE %>%
  readRDS() %>%
  select(all_of(c("comment_id", "scalar_sum"))) %>%
  mutate(dict_binary = as.integer(scalar_sum >= 1))
cat("  Loaded:", nrow(dict_data), "comments\n")

# ============================================================
# ROC 1: BERT vs LLM CONSENSUS (Ground Truth)
# ============================================================

cat("\n--- ROC 1: BERT vs LLM Consensus ---\n")

# Locate the LLM-labelled comparison exports. `pattern` is a regular
# expression, so the dot is escaped and the match is anchored at the end of
# the file name; otherwise names such as "ALL_MODELS_COMPARISON.csv.bak"
# would be picked up as well.
llm_files <- list.files(
  LLM_RESULTS_DIR,
  pattern = "ALL_MODELS_COMPARISON\\.csv$",
  full.names = TRUE
)

# Defaults used whenever the curve cannot be computed. Later sections test
# `roc_llm` for NULL before using the LLM results.
roc_llm <- NULL
auc_llm <- NA

if (length(llm_files) > 0) {

  # Read comment_id as character so it has the same type as the join key in
  # bert_full (parsed as character); a guessed numeric id would make the
  # inner_join() below fail on incompatible key types.
  llm_data <- llm_files %>%
    map(
      read_csv,
      col_types = cols(comment_id = col_character()),
      show_col_types = FALSE
    ) %>%
    bind_rows()
  cat("  Loaded LLM labels:", nrow(llm_data), "comments\n")

  # One vote column per LLM, identified by its name prefix.
  violence_cols <- str_subset(names(llm_data), "^discusses_violence_")
  cat("  Violence columns found:", paste(violence_cols, collapse = ", "), "\n")

  # Consensus ground truth: a comment is positive when at least 2 models
  # voted "violence". Missing votes are ignored (treated as no vote).
  # rowSums() counts the votes for all rows at once instead of row by row.
  llm_labeled <- llm_data %>%
    mutate(
      violence_votes = rowSums(llm_data[violence_cols], na.rm = TRUE),
      ground_truth = as.integer(violence_votes >= 2)
    ) %>%
    select(comment_id, ground_truth, violence_votes)

  cat("  Ground truth distribution:\n")
  print(table(llm_labeled$ground_truth))

  # Join with BERT predictions. Rows without a BERT score are dropped here
  # because roc() would discard them anyway; filtering first keeps the
  # reported n equal to the number of observations behind the curve.
  roc_data_llm <- bert_full %>%
    inner_join(llm_labeled, by = "comment_id") %>%
    filter(!is.na(vpi_binary_confidence))

  cat("  Matched comments:", nrow(roc_data_llm), "\n")

  if (nrow(roc_data_llm) > 0 && n_distinct(roc_data_llm$ground_truth) == 2) {

    # Fix the class order (0 = control, 1 = case) and the direction (higher
    # score means more likely a case). With pROC's automatic direction the
    # comparison can be flipped to whichever side gives AUC >= 0.5, which
    # biases the estimate upward for weak classifiers.
    # NOTE(review): this assumes vpi_binary_confidence is the score for the
    # positive (violence) class rather than the confidence in whichever label
    # was predicted - confirm against the scoring script.
    roc_llm <- roc(
      roc_data_llm$ground_truth,
      roc_data_llm$vpi_binary_confidence,
      levels = c(0, 1),
      direction = "<"
    )
    auc_llm <- auc(roc_llm)

    # Optimal operating point. Ties can yield several "best" rows; keep the
    # first so the console output and legend show exactly one threshold.
    coords_best <- coords(
      roc_llm, "best",
      ret = c("threshold", "sensitivity", "specificity")
    )
    coords_best <- coords_best[1, , drop = FALSE]

    cat("  AUC:", round(auc_llm, 3), "\n")
    cat("  Optimal threshold:", round(coords_best$threshold, 3), "\n")
    cat("  Sensitivity:", round(coords_best$sensitivity, 3), "\n")
    cat("  Specificity:", round(coords_best$specificity, 3), "\n")

    # Plot
    pdf(file.path(OUTPUT_DIR, "roc_bert_vs_llm_consensus.pdf"), width = 7, height = 6)
    par(mar = c(5, 4, 4, 2))

    # pROC draws specificity on a reversed x-axis; legacy.axes only relabels
    # it as 1 - specificity. The chance line is therefore y = 1 - x, which
    # plot() already draws itself. It is styled through identity.* instead of
    # adding abline(a = 0, b = 1), which would trace the opposite diagonal.
    plot(roc_llm,
         main = "BERT Violence Classifier: ROC Curve",
         col = "steelblue",
         lwd = 2.5,
         legacy.axes = TRUE,
         identity.col = "gray50",
         identity.lty = 2,
         xlab = "False Positive Rate (1 - Specificity)",
         ylab = "True Positive Rate (Sensitivity)")

    # Subtitle placed just above the plot region, under the main title
    mtext(paste0("Ground Truth: LLM Consensus, n = ", nrow(roc_data_llm)),
          side = 3, line = 0.3, cex = 0.9)

    legend("bottomright",
           legend = c(paste0("AUC = ", round(auc_llm, 3)),
                      paste0("Optimal threshold = ", round(coords_best$threshold, 3))),
           bty = "n")
    dev.off()

    cat("  Saved: roc_bert_vs_llm_consensus.pdf\n")

  } else {
    cat("  ERROR: Not enough data or only one class present\n")
  }

} else {
  cat("  WARNING: No LLM comparison files found in", LLM_RESULTS_DIR, "\n")
}

# ============================================================
# ROC 2: BERT vs DICTIONARY
# ============================================================

cat("\n--- ROC 2: BERT vs Dictionary ---\n")

# Merge BERT and dictionary scores on the comment id. Rows with a missing
# dictionary label or BERT score are dropped: roc() would discard them anyway,
# and keeping them would turn the positive rate below into NA and overstate n.
roc_data_dict <- bert_full %>%
  inner_join(dict_data, by = "comment_id") %>%
  filter(!is.na(dict_binary), !is.na(vpi_binary_confidence))

cat("  Matched comments:", nrow(roc_data_dict), "\n")

# A ROC curve needs both classes. Stop with an explicit message instead of
# letting roc() fail; the sections below use roc_dict unconditionally, so
# there is nothing sensible to continue with.
if (nrow(roc_data_dict) == 0 || n_distinct(roc_data_dict$dict_binary) != 2) {
  stop(
    "ROC 2 needs matched comments containing both dictionary classes (0 and 1)",
    call. = FALSE
  )
}

cat("  Dictionary positive rate:", round(mean(roc_data_dict$dict_binary) * 100, 2), "%\n")

# Compute ROC (BERT confidence predicting dictionary labels). Class order and
# direction are fixed (0 = control, 1 = case, higher score = case) so pROC
# cannot flip the comparison to inflate the AUC.
# NOTE(review): assumes vpi_binary_confidence is the score for the positive
# (violence) class - confirm against the scoring script.
roc_dict <- roc(
  roc_data_dict$dict_binary,
  roc_data_dict$vpi_binary_confidence,
  levels = c(0, 1),
  direction = "<"
)
auc_dict <- auc(roc_dict)

cat("  AUC:", round(auc_dict, 3), "\n")

# Plot
png(file.path(OUTPUT_DIR, "roc_bert_vs_dictionary.png"), width = 700, height = 600, res = 100)
# pROC draws specificity on a reversed x-axis (legacy.axes only relabels it),
# so the chance line is y = 1 - x and plot() draws it itself. It is styled via
# identity.* rather than adding abline(a = 0, b = 1), which would trace the
# opposite diagonal.
plot(roc_dict,
     main = "BERT Predictions vs Dictionary Labels: ROC Curve",
     sub = paste0("Reference: Dictionary (scalar_sum ≥ 1), n = ", scales::comma(nrow(roc_data_dict))),
     col = "darkorange",
     lwd = 2.5,
     legacy.axes = TRUE,
     identity.col = "gray50",
     identity.lty = 2,
     xlab = "False Positive Rate (1 - Specificity)",
     ylab = "True Positive Rate (Sensitivity)")
legend("bottomright",
       legend = paste0("AUC = ", round(auc_dict, 3)),
       bty = "n")
# invisible(): keeps dev.off()'s return value out of the console log
invisible(dev.off())

cat("  Saved: roc_bert_vs_dictionary.png\n")

# ============================================================
# COMBINED PLOT (if both available)
# ============================================================

# Overlay both curves in one figure. Skipped when the LLM-consensus curve
# could not be computed (roc_llm is NULL in that case).
if (!is.null(roc_llm)) {

  cat("\n--- Combined ROC Plot ---\n")

  png(file.path(OUTPUT_DIR, "roc_combined.png"), width = 800, height = 600, res = 100)

  # The first plot() call sets up the axes and draws the chance line. pROC
  # uses specificity on a reversed x-axis (legacy.axes only relabels it), so
  # the chance line is y = 1 - x; it is styled through identity.* because an
  # extra abline(a = 0, b = 1) would trace the opposite diagonal.
  plot(roc_llm,
       main = "BERT Violence Classifier: ROC Comparison",
       col = "steelblue",
       lwd = 2.5,
       legacy.axes = TRUE,
       identity.col = "gray50",
       identity.lty = 2,
       xlab = "False Positive Rate (1 - Specificity)",
       ylab = "True Positive Rate (Sensitivity)")

  # Second curve is added to the existing coordinate system
  plot(roc_dict, add = TRUE, col = "darkorange", lwd = 2.5)

  legend("bottomright",
         legend = c(paste0("vs LLM Consensus (AUC = ", round(auc_llm, 3), ", n = ", nrow(roc_data_llm), ")"),
                    paste0("vs Dictionary (AUC = ", round(auc_dict, 3), ", n = ", scales::comma(nrow(roc_data_dict)), ")")),
         col = c("steelblue", "darkorange"),
         lwd = 2.5,
         bty = "n")

  dev.off()

  cat("  Saved: roc_combined.png\n")
}

# ============================================================
# SUMMARY
# ============================================================

# Closing report: banner first, then the headline numbers of each comparison.
cat(
  "\n============================================================\n",
  "ROC ANALYSIS SUMMARY\n",
  "============================================================\n",
  sep = ""
)

# LLM-consensus figures are printed only when that curve was computed.
# Separators are written out explicitly (sep = "") so the whole section can
# go through a single cat() call.
if (!is.null(roc_llm)) {
  cat(
    "BERT vs LLM Consensus:\n",
    "  AUC:", " ", round(auc_llm, 3), " ", "\n",
    "  n:", " ", nrow(roc_data_llm), " ", "\n\n",
    sep = ""
  )
}

# Dictionary comparison figures
cat(
  "BERT vs Dictionary:\n",
  "  AUC:", " ", round(auc_dict, 3), " ", "\n",
  "  n:", " ", scales::comma(nrow(roc_data_dict)), " ", "\n\n",
  sep = ""
)

cat("Output files saved to:", OUTPUT_DIR, "\n")